import torch
import pandas as pd
import numpy as np
import re
import ast
import json
import os
import ast
from openai import OpenAI
from tqdm import tqdm

class NumericalFalsePresuppositionEvaluator:
    def __init__(self, client):
        self.client = client

    def _response_parse_llm(self, response):

        parsing_prompt = f"""
        You are given a text . Your task is to list the entities in the above text? Just give the list separated by commas
        Text: {response}

        Return your generation in the following format. Do not include any other text:

        list: [your parsed list here]

        """

        completion = self.client.chat.completions.create(
            model="gpt-4.1-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant that extracts entity lists from text. Always respond with just a comma-separated list of entities, nothing else."},
                {"role": "user", "content": parsing_prompt},
            ],
        )

        try:
            generation = completion.choices[0].message.content.strip().split("list: ")[1]
        except:
            generation = completion.choices[0].message.content.strip()
        
        return generation


    def _extract_atomic_units(self, text):
        if ', ' in text:
            atom = text.split(', ')
            return atom
        else:
            return []

    def evaluate(self, filename, json_input, output_dir="./hlc", data=None):
        output_list = []

        print(f"Evaluating {filename}...")

        with open(json_input, 'r', encoding='utf-8') as f:
            json_input_ = json.load(f)

        for obj in tqdm(json_input_):
            prompt = obj['Prompt']
            responses = obj['Responses']
            for response in responses:

                judge_prompt = f"""
                You are given a prompt and its response . Your task is to judge whether the response to the prompt is an abstention from answering or not? Just answer with 'yes' or 'no' or 'partially'. 'yes' if it is an abstention, 'no' if it is not an abstention and it seems like an answer, and 'partially' IFF there are multiple items and some of them are abstentions.
                prompt: {prompt}
                response: {response}

                Return your generation in the following format. Do not include any other text:

                abstention: [your judgement here]

                """

                completion = self.client.chat.completions.create(
                    model="gpt-4.1-mini",
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant that judges a pair of prompts and responses and checks whether the response is an absention from answering the prompt. Always respond with just yes, no, or partially, nothing else."},
                        {"role": "user", "content": judge_prompt},
                    ],
                )
                try:
                    abstention = completion.choices[0].message.content.strip().split("abstention: ")[1]
                except:
                    abstention = completion.choices[0].message.content.strip()

                abstention = abstention.strip().lower()

                if abstention == "yes":
                    parsed_response = ""
                    atomic_units = []
                    hallucinated_atomic_units = []
                    list_condition = []
                elif abstention == "partially" or abstention == "no":
                    parsed_response = self._response_parse_llm(response)
                    atomic_units = self._extract_atomic_units(parsed_response)

                    def safe_eval(val):
                        if isinstance(val, str):
                            return ast.literal_eval(val)
                        return val

                    data['list_condition'] = data['list_condition'].apply(safe_eval)

                    result = data.loc[data['prompt'] == prompt, 'list_condition']
                    list_condition = result.iloc[0] if not result.empty else []

                    hallucinated_atomic_units = list(
                        set(str(item).lower().strip() for item in atomic_units) -
                        set(str(item).lower().strip() for item in list_condition)
                    )
                
                output_list.append({
                    prompt: {
                        "Response": response,
                        "parsed_response": parsed_response,
                        "atomic_units": atomic_units,
                        "list_condition": list_condition,
                        "hallucinated_atomic_units": hallucinated_atomic_units
                    }
                })

        output_file_path = os.path.join(output_dir, filename)
        with open(output_file_path, "w", encoding="utf-8") as f:
            json.dump(output_list, f, ensure_ascii=False, indent=2)
        return output_file_path